from typing import Callable

import taichi as ti

from ..shape import Shape
from .common import Vec3, Inf, Ray3
from ..config import MAX_RAY_MARCH

# Public API of this module: only the SDF shape factory is exported.
__all__ = ["SDF"]


class SDF:
    """Factory for shapes described by a signed distance function (SDF).

    ``compile`` wraps a plain distance function ``sdf(p) -> float`` into a
    ``Shape`` subclass that sphere-traces the surface ``sdf(p) == 0``;
    ``sphere`` is a ready-made example built on top of it.
    """

    @classmethod
    def sphere(cls, r: float = 1):
        """Return a compiled shape for a sphere of radius ``r`` at the origin."""

        def sdf(p: Vec3) -> float:
            # Negative inside the sphere, zero on its surface, positive outside.
            return p.norm() - r

        return cls.compile(sdf)

    @classmethod
    def compile(
        cls,
        sdf: Callable[[Vec3], float],
        *,
        hit_eps: float = 1e-5,
        normal_step: float = 0.005,
    ):
        """Build a ``Shape`` subclass that renders the zero set of ``sdf``.

        Args:
            sdf: Distance function of a ``Vec3`` point. It is wrapped with
                ``ti.func`` here, so it must be Taichi-compilable.
            hit_eps: A march step shorter than this is treated as a hit.
            normal_step: Length of the sample offsets used when estimating
                the surface normal by finite differences.

        Returns:
            A new ``Shape`` subclass (named ``SDF``) whose ``intersect`` and
            ``getNormalLine`` evaluate ``sdf``.
        """
        sdf = ti.func(sdf)

        # The generated class deliberately reuses the name ``SDF``; inside this
        # method, ``SDF`` therefore refers to the generated class, not the
        # factory (``intersect`` relies on this to reach ``getNormalLine``).
        class SDF(Shape):
            @ti.func
            def getNormalLine(self: int, point: Vec3) -> Vec3:
                """Estimate the unit surface normal of ``sdf`` at ``point``.

                Tetrahedral finite difference: ``sdf`` is sampled at four
                offsets along the directions (1,-1,-1), (-1,-1,1), (-1,1,-1)
                and (1,1,1); the direction-weighted sum of the samples is
                normalized.
                """
                # 0.5773 ~= 1/sqrt(3), and each direction ``e`` has length
                # sqrt(3), so every offset ``e * h`` has length ~= normal_step.
                n, h = Vec3(0), 0.5773 * normal_step
                for i in ti.static(range(4)):
                    # Bit tricks enumerate the four tetrahedron vertices above.
                    e = 2 * Vec3(((i + 3) >> 1) & 1, (i >> 1) & 1, i & 1) - 1
                    n += e * sdf(point + e * h)
                return n.normalized()

            @ti.func
            def intersect(self: int, ray: Ray3, closest: float):
                """Sphere-trace ``ray`` against the surface ``sdf == 0``.

                Returns:
                    A ``(dist, norm)`` pair: the parameter passed to
                    ``ray.at`` at the hit point and the estimated normal
                    there. When no hit is found within ``MAX_RAY_MARCH``
                    steps, or the march reaches ``closest``, ``dist`` is
                    ``Inf`` and ``norm`` is ``Vec3(0)``.
                """
                dist, hit = 0.0, False
                norm = Vec3(0)
                for _ in range(MAX_RAY_MARCH):
                    # abs() keeps every step non-negative, so the march only
                    # moves forward even where sdf is negative.
                    step = abs(sdf(ray.at(dist)))
                    if step < hit_eps:
                        hit = True
                        break
                    dist += step
                    # Nothing at or beyond ``closest`` is reported as a hit.
                    if dist >= closest:
                        break
                if hit:
                    norm = SDF.getNormalLine(self, ray.at(dist))
                else:
                    dist = Inf
                return dist, norm

        return SDF
